/*
 * AIEngine a new generation network intrusion detection system.
 *
 * Copyright (C) 2013-2023  Luis Campo Giralte
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301, USA.
 *
 * Written by Luis Campo Giralte <luis.camp0.2009@gmail.com>
 *
 */
#include "StackVirtual.h"

namespace aiengine {

// Builds the "Virtual network stack" and wires all of its layers together.
//
// Physical side: Ethernet -> IP -> UDP or GRE. Tunnelled traffic is reached
// either through GRE (multiplexer path, carrying Ethernet or IP) or through
// VXLAN (reached from the physical UDP layer via its FlowForwarder, carrying
// Ethernet). Virtual side: Ethernet -> IP -> TCP / UDP / ICMP, with the
// application protocols attached to the virtual TCP and UDP FlowForwarders.
//
// setMode("full") runs last, after every FlowForwarder it toggles has been
// connected.
StackVirtual::StackVirtual() {

	name("Virtual network stack");

	// Add the specific Protocol objects.
	// NOTE(review): the boolean argument presumably marks whether the protocol
	// is active in this stack (vlan, mpls and pppoe are registered with false)
	// — confirm against the base class addProtocol().
	addProtocol(eth, true);
	addProtocol(vlan, false);
	addProtocol(mpls, false);
	addProtocol(pppoe, false);
	addProtocol(ip_, true);
	addProtocol(gre_, true);
	addProtocol(udp_, true);
	addProtocol(vxlan_, true);
	addProtocol(eth_vir_, true);
	addProtocol(ip_vir_, true);
	addProtocol(tcp_vir_, true);
	addProtocol(udp_vir_, true);
	addProtocol(icmp_, true);

        // Application protocols analyzed on top of the virtual TCP/UDP layers.
        addProtocol(http);
        addProtocol(ssl);
        addProtocol(smtp);
        addProtocol(imap);
        addProtocol(pop);
        addProtocol(bitcoin);
        addProtocol(tcp_generic);
        addProtocol(freqs_tcp, false);
        addProtocol(dns);
        addProtocol(sip);
        addProtocol(dhcp);
        addProtocol(ntp);
        addProtocol(snmp);
        addProtocol(ssdp);
        addProtocol(rtp);
        addProtocol(quic);
        addProtocol(udp_generic);
        addProtocol(freqs_udp, false);

	// The physical UDP FlowManager has a 24 hour timeout (86400 seconds);
	// setFlowsTimeout() only changes the virtual FlowManagers.
	flow_table_udp_->setTimeout(86400);

	// Attach a FlowCache to each FlowManager.
	flow_table_udp_->setFlowCache(flow_cache_udp_);
	flow_table_udp_vir_->setFlowCache(flow_cache_udp_vir_);
	flow_table_tcp_vir_->setFlowCache(flow_cache_tcp_vir_);

	// Configure the protocols with the Multiplexers and FlowForwarders
	ip_->setMultiplexer(mux_ip);
	mux_ip->setProtocol(static_cast<ProtocolPtr>(ip_));

        gre_->setMultiplexer(mux_gre_);
        mux_gre_->setProtocol(static_cast<ProtocolPtr>(gre_));

        udp_->setMultiplexer(mux_udp_);
        mux_udp_->setProtocol(static_cast<ProtocolPtr>(udp_));
        ff_udp_->setProtocol(static_cast<ProtocolPtr>(udp_));

        // NOTE(review): vxlan_->setFlowForwarder(ff_vxlan_) is called again in
        // the FlowForwarder section below; one of the two calls is redundant.
        vxlan_->setFlowForwarder(ff_vxlan_);
        vxlan_->setMultiplexer(mux_vxlan_);
        mux_vxlan_->setProtocol(static_cast<ProtocolPtr>(vxlan_));
        ff_vxlan_->setProtocol(static_cast<ProtocolPtr>(vxlan_));

	icmp_->setMultiplexer(mux_icmp_);
	mux_icmp_->setProtocol(static_cast<ProtocolPtr>(icmp_));

	// Configuring the Virtual layers
        eth_vir_->setMultiplexer(mux_eth_vir_);
        mux_eth_vir_->setProtocol(static_cast<ProtocolPtr>(eth_vir_));

        ip_vir_->setMultiplexer(mux_ip_vir_);
        mux_ip_vir_->setProtocol(static_cast<ProtocolPtr>(ip_vir_));

	udp_vir_->setMultiplexer(mux_udp_vir_);
	mux_udp_vir_->setProtocol(static_cast<ProtocolPtr>(udp_vir_));
	ff_udp_vir_->setProtocol(static_cast<ProtocolPtr>(udp_vir_));

	tcp_vir_->setMultiplexer(mux_tcp_vir_);
	mux_tcp_vir_->setProtocol(static_cast<ProtocolPtr>(tcp_vir_));
	ff_tcp_vir_->setProtocol(static_cast<ProtocolPtr>(tcp_vir_));

	// Configure the multiplexers of the physical layers:
	// Ethernet -> IP, and IP -> UDP / GRE (each with its link back down).
	mux_eth->addUpMultiplexer(mux_ip);
	mux_ip->addDownMultiplexer(mux_eth);
	mux_ip->addUpMultiplexer(mux_udp_);
	mux_udp_->addDownMultiplexer(mux_ip);
	mux_ip->addUpMultiplexer(mux_gre_); // TODO LUIS
	mux_gre_->addDownMultiplexer(mux_ip);

        // Configure the multiplexers of the virtual layers: GRE may carry
        // either Ethernet or IP directly, VXLAN carries Ethernet.
        mux_gre_->addUpMultiplexer(mux_eth_vir_);
        mux_gre_->addUpMultiplexer(mux_ip_vir_);
        mux_vxlan_->addUpMultiplexer(mux_eth_vir_);

	// TODO: mux_eth_vir_ is reached from both mux_gre_ and mux_vxlan_, so it
	// should have two down multiplexers, but only mux_vxlan_ is registered;
	// the single reference is kept just to keep the memory under control.
	//
        mux_eth_vir_->addDownMultiplexer(mux_vxlan_);
        mux_eth_vir_->addUpMultiplexer(mux_ip_vir_);
        mux_ip_vir_->addDownMultiplexer(mux_eth_vir_);
        mux_ip_vir_->addUpMultiplexer(mux_icmp_);
        mux_icmp_->addDownMultiplexer(mux_ip_vir_);
        mux_ip_vir_->addUpMultiplexer(mux_udp_vir_);
        mux_udp_vir_->addDownMultiplexer(mux_ip_vir_);
        mux_ip_vir_->addUpMultiplexer(mux_tcp_vir_);
        mux_tcp_vir_->addDownMultiplexer(mux_ip_vir_);

	// Connect the FlowManager and FlowCache of each transport layer, and
	// link every FlowManager back to the protocol that owns it.
	tcp_vir_->setFlowCache(flow_cache_tcp_vir_);
	tcp_vir_->setFlowManager(flow_table_tcp_vir_);
	flow_table_tcp_vir_->setProtocol(tcp_vir_);

	udp_vir_->setFlowCache(flow_cache_udp_vir_);
	udp_vir_->setFlowManager(flow_table_udp_vir_);
	flow_table_udp_vir_->setProtocol(udp_vir_);

	udp_->setFlowCache(flow_cache_udp_);
	udp_->setFlowManager(flow_table_udp_);
	flow_table_udp_->setProtocol(udp_);

        // Protocols that contain objects should have a reference to the
        // FlowManager of the virtual transport layer they run on.
        http->setFlowManager(flow_table_tcp_vir_);
        ssl->setFlowManager(flow_table_tcp_vir_);
        smtp->setFlowManager(flow_table_tcp_vir_);
        imap->setFlowManager(flow_table_tcp_vir_);
        pop->setFlowManager(flow_table_tcp_vir_);

        dns->setFlowManager(flow_table_udp_vir_);
        sip->setFlowManager(flow_table_udp_vir_);
        ssdp->setFlowManager(flow_table_udp_vir_);
        dhcp->setFlowManager(flow_table_udp_vir_);

        freqs_tcp->setFlowManager(flow_table_tcp_vir_);
        freqs_udp->setFlowManager(flow_table_udp_vir_);

	// Connect the AnomalyManager with the protocols that may have anomalies
        ip_->setAnomalyManager(anomaly_);
        tcp_vir_->setAnomalyManager(anomaly_);
        udp_vir_->setAnomalyManager(anomaly_);
        udp_->setAnomalyManager(anomaly_);

	// Configure the FlowForwarders: the physical UDP forwarder hands flows
	// up to VXLAN; the virtual TCP/UDP forwarders feed the applications.
	udp_->setFlowForwarder(ff_udp_);
	ff_udp_->addUpFlowForwarder(ff_vxlan_);
	vxlan_->setFlowForwarder(ff_vxlan_);
	tcp_vir_->setFlowForwarder(ff_tcp_vir_);
	udp_vir_->setFlowForwarder(ff_udp_vir_);

        // Register the virtual transport forwarders as the stack defaults.
        setTCPDefaultForwarder(ff_tcp_vir_);
        setUDPDefaultForwarder(ff_udp_vir_);

        std::ostringstream msg;
        msg << name() << " ready.";

        infoMessage(msg.str());

	// Start in full mode: every application FlowForwarder enabled.
	setMode("full");
}

void StackVirtual::showFlows(std::basic_ostream<char> &out, std::function<bool (const Flow&)> condition, int protocol) const {

	// Header line: number of flows held by the virtual TCP, virtual UDP
	// and physical UDP flow tables together.
	int flows_on_memory = flow_table_tcp_vir_->getTotalFlows() + flow_table_udp_vir_->getTotalFlows();
	flows_on_memory += flow_table_udp_->getTotalFlows();
	out << "Flows on memory " << flows_on_memory << std::endl;

	// Dump the tables selected by the requested protocol; any value other
	// than TCP or UDP dumps all three tables.
	switch (protocol) {
	case IPPROTO_TCP:
		flow_table_tcp_vir_->showFlows(out, condition);
		break;
	case IPPROTO_UDP:
		flow_table_udp_->showFlows(out, condition);
		flow_table_udp_vir_->showFlows(out, condition);
		break;
	default:
		flow_table_udp_->showFlows(out, condition);
		flow_table_tcp_vir_->showFlows(out, condition);
		flow_table_udp_vir_->showFlows(out, condition);
		break;
	}
}

void StackVirtual::showFlows(Json &out, std::function<bool (const Flow&)> condition, int protocol) const {

	// TCP only lives in the virtual TCP table.
	if (protocol == IPPROTO_TCP) {
		flow_table_tcp_vir_->showFlows(out, condition);
		return;
	}

	// UDP (or "any protocol") always includes the physical UDP table first.
	// The virtual TCP table is inserted in the middle only when the caller
	// did not restrict the output to UDP.
	flow_table_udp_->showFlows(out, condition);
	if (protocol != IPPROTO_UDP)
		flow_table_tcp_vir_->showFlows(out, condition);
	flow_table_udp_vir_->showFlows(out, condition);
}

void StackVirtual::statistics(std::basic_ostream<char> &out) const {

	// This stack adds nothing of its own; the base class prints everything.
	super_::statistics(out);
}

void StackVirtual::setMode(const std::string &mode) {

        std::initializer_list<SharedPointer<FlowForwarder>> tcp_list {
                ff_tcp_vir_, ff_http, ff_ssl, ff_smtp, ff_imap, ff_pop, ff_bitcoin,
                ff_tcp_generic, ff_tcp_freqs};
        std::initializer_list<SharedPointer<FlowForwarder>> udp_list {
                ff_udp_vir_, ff_dns, ff_sip, ff_dhcp, ff_ntp, ff_snmp, ff_ssdp,
                ff_rtp, ff_quic, ff_udp_generic, ff_udp_freqs};

        if (auto it = modes.find(mode); it != modes.end()) {
                std::ostringstream msg;

                disableFlowForwarders(tcp_list);
                disableFlowForwarders(udp_list);

                if ((*it).second == Mode::FULL) {

			operation_mode_ = "full";
                        msg << "Enable FullEngine mode on " << name();
                        enableFlowForwarders(tcp_list);
                        enableFlowForwarders(udp_list);

                } else if ((*it).second == Mode::FREQUENCIES) {

			operation_mode_ = "frequency";
                        msg << "Enable FrequencyEngine mode on " << name();
                        enableFlowForwarders({ff_tcp_vir_, ff_tcp_freqs});
                        enableFlowForwarders({ff_udp_vir_, ff_udp_freqs});

                } else if ((*it).second == Mode::NIDS) {

			operation_mode_ = "nids";
                        msg << "Enable NIDSEngine mode on " << name();
                        enableFlowForwarders({ff_tcp_vir_, ff_tcp_generic});
                        enableFlowForwarders({ff_udp_vir_, ff_udp_generic});
                }
                infoMessage(msg.str());
        }
}

void StackVirtual::setTotalTCPFlows(int value) {

	// Preallocate the virtual TCP flows and their TCP state objects.
	flow_cache_tcp_vir_->create(value);
	tcp_vir_->createTCPInfos(value);

	// Heuristic share of the TCP flows preallocated per protocol:
	// HTTP is assumed to dominate (75%), SSL follows (40%), and SMTP,
	// IMAP, POP and Bitcoin get an optimistic 5% each.
	const auto http_share = value * 0.75;
	const auto ssl_share = value * 0.4;
	const auto minor_share = value * 0.05;

	http->increaseAllocatedMemory(http_share);
	ssl->increaseAllocatedMemory(ssl_share);
	smtp->increaseAllocatedMemory(minor_share);
	imap->increaseAllocatedMemory(minor_share);
	pop->increaseAllocatedMemory(minor_share);
	bitcoin->increaseAllocatedMemory(minor_share);

	// The frequency analyzer gets a fixed amount regardless of value.
	freqs_tcp->increaseAllocatedMemory(16);
}

void StackVirtual::setTotalUDPFlows(int value) {

	// The physical UDP cache gets 1/32 of the requested flows, while the
	// virtual UDP cache gets the full amount.
	flow_cache_udp_->create(value / 32);
	flow_cache_udp_vir_->create(value);

	// Heuristic per protocol shares: DNS half of the flows, SIP and SSDP
	// 20% each, DHCP 10%; the frequency analyzer gets a fixed amount.
	dns->increaseAllocatedMemory(value / 2);

	const auto sip_ssdp_share = value * 0.2;
	sip->increaseAllocatedMemory(sip_ssdp_share);
	ssdp->increaseAllocatedMemory(sip_ssdp_share);
	dhcp->increaseAllocatedMemory(value * 0.1);
	freqs_udp->increaseAllocatedMemory(16);
}

int StackVirtual::getTotalTCPFlows() const {

	// Total flows reported by the virtual TCP flow cache.
	return flow_cache_tcp_vir_->getTotalFlows();
}

int StackVirtual::getTotalUDPFlows() const {

	// Total flows reported by the virtual UDP flow cache; the physical
	// UDP cache is not included.
	return flow_cache_udp_vir_->getTotalFlows();
}

void StackVirtual::setFlowsTimeout(int timeout) {

	// Only the virtual flow tables follow the user timeout; the physical
	// UDP table keeps the fixed timeout set when the stack is built.
	flow_table_tcp_vir_->setTimeout(timeout);
	flow_table_udp_vir_->setTimeout(timeout);
}

void StackVirtual::setTCPRegexManager(const SharedPointer<RegexManager> &rm) {

	// Share the same RegexManager with the virtual TCP layer and the
	// generic TCP analyzer, then let the base class record it as well.
	tcp_vir_->setRegexManager(rm);
	tcp_generic->setRegexManager(rm);

	super_::setTCPRegexManager(rm);
}

void StackVirtual::setUDPRegexManager(const SharedPointer<RegexManager> &rm) {

	// Share the same RegexManager with the virtual UDP layer and the
	// generic UDP analyzer, then let the base class record it as well.
	udp_vir_->setRegexManager(rm);
	udp_generic->setRegexManager(rm);

	super_::setUDPRegexManager(rm);
}

void StackVirtual::setTCPIPSetManager(const SharedPointer<IPSetManager> &ipset_mng) {

	// The virtual TCP layer consults the IPSetManager; the base class
	// keeps its own reference too.
	tcp_vir_->setIPSetManager(ipset_mng);

	super_::setTCPIPSetManager(ipset_mng);
}

void StackVirtual::setUDPIPSetManager(const SharedPointer<IPSetManager> &ipset_mng) {

	// The virtual UDP layer consults the IPSetManager; the base class
	// keeps its own reference too.
	udp_vir_->setIPSetManager(ipset_mng);

	super_::setUDPIPSetManager(ipset_mng);
}

#if defined(JAVA_BINDING)

void StackVirtual::setTCPRegexManager(RegexManager *sig) {

	// Binding entry point: wrap the raw pointer and forward it to the
	// SharedPointer overload; a null pointer becomes an empty SharedPointer.
	// NOTE(review): reset() hands ownership of sig to the SharedPointer —
	// confirm the binding side never deletes it as well.
	SharedPointer<RegexManager> manager;
	if (sig)
		manager.reset(sig);

	setTCPRegexManager(manager);
}

void StackVirtual::setUDPRegexManager(RegexManager *sig) {

	// Binding entry point: wrap the raw pointer and forward it to the
	// SharedPointer overload; a null pointer becomes an empty SharedPointer.
	// NOTE(review): reset() hands ownership of sig to the SharedPointer —
	// confirm the binding side never deletes it as well.
	SharedPointer<RegexManager> manager;
	if (sig)
		manager.reset(sig);

	setUDPRegexManager(manager);
}

void StackVirtual::setTCPIPSetManager(IPSetManager *ipset_mng) {

	// Binding entry point: wrap the raw pointer and forward it to the
	// SharedPointer overload; a null pointer becomes an empty SharedPointer.
	// NOTE(review): reset() hands ownership of ipset_mng to the
	// SharedPointer — confirm the binding side never deletes it as well.
	SharedPointer<IPSetManager> manager;
	if (ipset_mng)
		manager.reset(ipset_mng);

	setTCPIPSetManager(manager);
}

void StackVirtual::setUDPIPSetManager(IPSetManager *ipset_mng) {

	// Binding entry point: wrap the raw pointer and forward it to the
	// SharedPointer overload; a null pointer becomes an empty SharedPointer.
	// NOTE(review): reset() hands ownership of ipset_mng to the
	// SharedPointer — confirm the binding side never deletes it as well.
	SharedPointer<IPSetManager> manager;
	if (ipset_mng)
		manager.reset(ipset_mng);

	setUDPIPSetManager(manager);
}

#endif

std::tuple<Flow*, Flow*> StackVirtual::getCurrentFlows() const {

	// The low flow could be nullptr on gre encapsulations
        Flow *low_flow = udp_->getCurrentFlow();
        Flow *high_flow = nullptr;
	uint16_t proto = ip_vir_->getProtocol();;

        if (proto == IPPROTO_TCP)
        	high_flow = tcp_vir_->getCurrentFlow();
        else if (proto == IPPROTO_UDP)
        	high_flow = udp_vir_->getCurrentFlow();

#if GCC_VERSION < 50500
        return std::tuple<Flow*, Flow*>(low_flow, high_flow);
#else
        return {low_flow, high_flow};
#endif
}

SharedPointer<Flow> StackVirtual::getFlow(const FlowSearchOptions &fsos) const {

	// Every non-UDP search goes to the virtual TCP table.
	if (fsos.protocol != IPPROTO_UDP)
		return flow_table_tcp_vir_->find(fsos);

	// UDP: look in the physical table first, then in the virtual one.
	if (SharedPointer<Flow> outer = flow_table_udp_->find(fsos))
		return outer;

	return flow_table_udp_vir_->find(fsos);
}

} // namespace aiengine
